Packages and utilities

library(tidyverse)
library(lme4)
library(lmerTest)
library(logging)
library(mvtnorm)
library(mgcv)
# Total held-out log-likelihood of `test_y` under a fitted regression.
#
# Each observation is scored under a Normal centred on the model's prediction
# for the matching row of `test_X`, with the residual standard deviation
# estimated from the training data (sigma()). re.form = NA is forwarded to
# predict(); for lme4 fits it requests population-level predictions, while
# predict.lm() simply absorbs it through `...`.
# Returns a single number: the sum of the per-row log-densities.
logLik_test <- function(lm, test_X, test_y) {
  mu <- predict(lm, test_X, re.form = NA)
  resid_sd <- sigma(lm)
  log_densities <- dnorm(test_y, mean = mu, sd = resid_sd, log = TRUE)
  sum(log_densities)
}
# Per-observation held-out log-likelihood.
#
# Same scoring as logLik_test() but without the final reduction: returns one
# log-density per element of `test_y`, evaluated under
# N(prediction for that row, sigma(lm)) where sigma is the training-data
# residual standard deviation.
logLik_test_per <- function(lm, test_X, test_y) {
  dnorm(
    x = test_y,
    mean = predict(lm, test_X, re.form = NA),
    sd = sigma(lm),
    log = TRUE
  )
}
# Mean squared prediction error of a fitted regression on held-out data.
# Predictions use re.form = NA (population-level for lme4 fits; ignored by
# predict.lm()). Returns a single number.
mse_test <- function(lm, test_X, test_y) {
  errors <- predict(lm, test_X, re.form = NA) - test_y
  mean(errors^2)
}
#Sanity checks
#mylm <- gam(psychometric ~  s(surprisal, bs = "cr", k = 20) + s(prev_surp, bs = "cr", k = 20) + te(freq, len, bs = "cr") + te(prev_freq, prev_len, bs = "cr"), data=train_data)
#c(logLik(mylm), logLik_test(mylm, train_data, train_data$psychometric))
#logLik_test(mylm, test_data, test_data$psychometric)

Data loading and preprocessing

data <- read.csv("../data/harmonized_results.csv")

# Attach 1-, 2- and 3-back copies of each token's features (and token code)
# within every corpus--model--training--seed series. lag() shifts by row
# position, so this assumes rows are stored in token order within each
# series -- TODO confirm against whatever writes the CSV.
all_data <- data %>%
  mutate(seed = as.factor(seed)) %>%
  group_by(corpus, model, training, seed) %>%
    mutate(prev_surp = lag(surprisal),
           prev_code = lag(code),
           prev_len = lag(len),
           prev_freq = lag(freq),

           prev2_freq = lag(prev_freq),
           prev2_code = lag(prev_code),
           prev2_len = lag(prev_len),
           prev2_surp = lag(prev_surp),

           prev3_freq = lag(prev2_freq),
           prev3_code = lag(prev2_code),
           prev3_len = lag(prev2_len),
           prev3_surp = lag(prev2_surp)) %>%
  ungroup() %>%

  # Keep only tokens whose spillover history is contiguous: three tokens back
  # for the Dundee corpus (its regressions use prev3_* predictors) and one
  # token back for all other corpora (their regressions use prev_* only).
  # NB this removes most zero-surprisal rows, since early-sentence tokens don't
  # have contiguous token history; any that remain are dropped explicitly later.
  filter((corpus == "dundee" & code == prev3_code + 3) |
           (corpus != "dundee" & code == prev_code + 1)) %>%

  # Token codes were only needed for the contiguity check above.
  select(-prev_code, -prev2_code, -prev3_code) %>%
  drop_na()

# Harmonise the "gpt-2" label to "gpt2" (round-tripping through character so
# the factor levels are rebuilt), then apply the development-time cap on
# surprisal values.
all_data <- all_data %>%
  mutate(model = as.character(model)) %>%
  mutate(model = if_else(model == "gpt-2", "gpt2", model)) %>%
  mutate(model = as.factor(model)) %>%
  # DEV
  filter(surprisal < 15)
# Expand to every (corpus token) x (model, training, seed) combination, keep
# tokens for which at least one combination has no data, and return just the
# combinations that are absent from all_data.
missing_rows <- all_data %>%
  complete(nesting(corpus, code), nesting(model, training, seed)) %>%
  group_by(corpus, code) %>%
  filter(any(is.na(surprisal))) %>%
  ungroup() %>%
  anti_join(all_data, by = c("corpus", "code", "model", "training", "seed"))

# Bar chart of missing rows per corpus, split by model + training set.
missing_rows %>%
  ggplot(aes(x = corpus, fill = factor(paste(model, training)))) +
  geom_bar(position = position_dodge(width = 0.8))

# Count missing rows per model/training/seed/corpus, largest first. The sort
# must happen inside print(): piping print()'s return value into arrange()
# prints the unsorted table and then auto-prints a second, sorted copy.
print(missing_rows %>%
        group_by(model, training, seed, corpus) %>%
        summarise(n = n()) %>%
        arrange(desc(n)))

# Find corpus tokens that are under-covered: within each corpus, the best-
# covered token has the maximum number of model--training--seed observations;
# every token with fewer observations than that maximum is marked for dropping.
# Result: one row per such token, columns (code, corpus).
to_drop <- all_data %>%
  count(corpus, code) %>%
  group_by(corpus) %>%
  filter(n != max(n)) %>%
  ungroup() %>%
  select(code, corpus)

#to_drop = all_data %>% group_by(corpus, code) %>% filter(n() != ideal_token_obs_count) %>% ungroup()
# to_drop holds one row per (corpus, code) token, so count the matching rows of
# all_data to report the number of observations (not tokens) being dropped.
loginfo(paste("Dropping",
              nrow(semi_join(all_data, to_drop, by = c("corpus", "code"))),
              "observations corresponding to corpus tokens which are missing observations for some model."))
2020-05-08 14:15:05 INFO::Dropping 17953 observations corresponding to corpus tokens which are missing observations for some model.
loginfo(paste("Dropping", nrow(distinct(to_drop, corpus, code)), "tokens which are missing observations for some model."))
2020-05-08 14:15:05 INFO::Dropping 17953 tokens which are missing observations for some model.
# Remove every observation of the under-covered tokens, then report what is left.
all_data <- anti_join(all_data, to_drop, by = c("corpus", "code"))
loginfo(paste("After drop,", nrow(all_data), "observations (",
              nrow(distinct(all_data, corpus, code)), " tokens) remain."))
2020-05-08 14:15:05 INFO::After drop, 853258 observations ( 29314  tokens) remain.

# All observations of any token that has a surprisal of exactly zero under at
# least one model--training--seed combination.
to_drop_zero_surps <- all_data %>%
  group_by(corpus, code) %>%
  filter(any(surprisal == 0)) %>%
  ungroup()
loginfo(paste("Dropping", nrow(to_drop_zero_surps),
              "observations corresponding to corpus tokens which have surprisal zeros for some model."))
2020-05-08 14:15:06 INFO::Dropping 146 observations corresponding to corpus tokens which have surprisal zeros for some model.
loginfo(paste("Dropping", nrow(distinct(to_drop_zero_surps, corpus, code)), "tokens which have surprisal zeros for some model."))
2020-05-08 14:15:06 INFO::Dropping 5 tokens which have surprisal zeros for some model.
# Remove the zero-surprisal tokens and report the remaining size.
all_data <- anti_join(all_data, to_drop_zero_surps, by = c("corpus", "code"))
loginfo(paste("After drop,", nrow(all_data), "observations (",
              nrow(distinct(all_data, corpus, code)), " tokens) remain."))
2020-05-08 14:15:06 INFO::After drop, 853112 observations ( 29309  tokens) remain.

# All observations of any token whose psychometric measure is exactly zero for
# at least one model--training--seed combination.
to_drop_zero_psychs <- all_data %>%
  group_by(corpus, code) %>%
  filter(any(psychometric == 0)) %>%
  ungroup()
# The message previously said "surprisal zeros" (copied from the block above);
# this filter is on psychometric values.
loginfo(paste("Dropping", nrow(to_drop_zero_psychs), "observations corresponding to corpus tokens which have psychometric zeros for some model."))
2020-05-08 14:15:06 INFO::Dropping 12934 observations corresponding to corpus tokens which have surprisal zeros for some model.
# Token-level count for the zero-psychometric drop (filter is on psychometric, not surprisal).
loginfo(paste("Dropping", to_drop_zero_psychs %>% group_by(corpus, code) %>% n_groups(), "tokens which have psychometric zeros for some model."))
2020-05-08 14:15:06 INFO::Dropping 446 tokens which have surprisal zeros for some model.
# Remove the zero-psychometric tokens and report the remaining size.
all_data <- anti_join(all_data, to_drop_zero_psychs, by = c("corpus", "code"))
loginfo(paste("After drop,", nrow(all_data), "observations (",
              nrow(distinct(all_data, corpus, code)), " tokens) remain."))
2020-05-08 14:15:07 INFO::After drop, 840178 observations ( 28863  tokens) remain.

Learn models

# Fit a linear model to one model--training--seed--corpus training group and
# score it on the matching rows of the held-out data.
#
# df:         training rows for a single group (as passed in by do()).
# test_data:  the whole held-out fold; it is subset here to the rows whose
#             training/model/seed/corpus values also occur in df.
# formula:    regression formula passed to lm().
# fold:       fold index, appended to the key under which the fit is stored.
# store_env:  environment in which the fitted model is saved, keyed by
#             "<model> <training> <seed> <corpus> <fold>", for later lookup.
# Returns summarise(df, ...) with columns log_lik (training log-likelihood),
# test_lik (summed held-out log-likelihood) and test_mse (held-out MSE).
get_lm_data <- function(df, test_data, formula, fold, store_env) {
  #this_lm <- gam(formula, data=df);
  this_lm <- lm(formula, data = df)
  this_test_data <- semi_join(test_data, df, by = c("training", "model", "seed", "corpus"))

  # Save the fit in store_env (not the global env) so residuals and per-token
  # log-likelihoods can be recomputed from it later.
  lm_name <- paste(unique(paste(df$model, df$training, df$seed, df$corpus))[1], fold)
  assign(lm_name, this_lm, envir = store_env)

  # REML = FALSE spelled out: `F` is an ordinary, reassignable variable.
  summarise(df,
            log_lik = as.numeric(logLik(this_lm, REML = FALSE)),
            test_lik = logLik_test(this_lm, this_test_data, this_test_data$psychometric),
            test_mse = mse_test(this_lm, this_test_data, this_test_data$psychometric))
}
# Score one held-out group with the fit previously stored in store_env under
# "<model> <training> <seed> <corpus> <fold>" (built from df's first row).
# Returns df with two added columns:
#   likelihood -- per-row log-density of psychometric under the stored model;
#   resid      -- observed psychometric minus the model's prediction.
get_lm_residuals <- function(df, fold, store_env) {
  group_label <- unique(paste(df$model, df$training, df$seed, df$corpus))[1]
  fitted_lm <- get(paste(group_label, fold), envir = store_env)

  mutate(df,
         likelihood = logLik_test_per(fitted_lm, df, df$psychometric),
         resid = df$psychometric - predict(fitted_lm, df, re.form = NA))
}
# Per-row improvement in held-out log-likelihood from adding surprisal terms.
#
# Looks up the baseline and full fits for this group/fold (key
# "<model> <training> <seed> <corpus> <fold>", from the first row of
# test_data) and returns test_data with an extra column delta_log_lik =
# log p(y | full) - log p(y | baseline) for each row.
get_lm_delta_log_lik <- function(test_data, fold, baseline_env, full_env) {
  group_key <- paste(
    unique(paste(test_data$model, test_data$training, test_data$seed, test_data$corpus))[1],
    fold
  )
  baseline_fit <- get(group_key, envir = baseline_env)
  full_fit <- get(group_key, envir = full_env)

  y <- test_data$psychometric
  per_row_delta <- logLik_test_per(full_fit, test_data, y) -
    logLik_test_per(baseline_fit, test_data, y)
  cbind(test_data, delta_log_lik = per_row_delta)
}
#####
# Define regression formulae.
# Dundee (eye-tracking, "rt") regressions use the current word plus the three
# preceding words; all other ("sprt") corpora use the current word plus one
# preceding word. Baselines omit the surprisal terms.
# The commented-out gam() variants below are earlier smooth-term versions.
#baseline_rt_regression = psychometric ~ te(freq, len, bs = "cr") + te(prev_freq, prev_len, bs = "cr")
#baseline_sprt_regression = psychometric ~ te(freq, len, bs = "cr") + te(prev_freq, prev_len, bs = "cr") + te(prev2_freq, prev2_len, bs = "cr")
#full_rt_regression = (psychometric ~ s(surprisal, bs = "cr", k = 20) + s(prev_surp, bs = "cr", k = 20)
                     #+ te(freq, len, bs = "cr") + te(prev_freq, prev_len, bs = "cr"))
#full_sprt_regression = (psychometric ~ s(surprisal, bs = "cr", k = 20) + s(prev_surp, bs = "cr", k = 20) + s(prev2_surp, bs = "cr", k = 20)
                        #+ te(freq, len, bs = "cr") + te(prev_freq, prev_len, bs = "cr") + te(prev2_freq, prev2_len, bs = "cr"))

baseline_rt_regression <- psychometric ~
  freq + prev_freq + prev2_freq + prev3_freq +
  len + prev_len + prev2_len + prev3_len
baseline_sprt_regression <- psychometric ~
  freq + prev_freq +
  len + prev_len

full_rt_regression <- psychometric ~
  surprisal + prev_surp + prev2_surp + prev3_surp +
  freq + prev_freq + prev2_freq + prev3_freq +
  len + prev_len + prev2_len + prev3_len
full_sprt_regression <- psychometric ~
  surprisal + prev_surp +
  freq + prev_freq +
  len + prev_len
  
#####
# Accumulators appended to once per cross-validation fold:
#   baseline_results / full_model_results -- per-group fit summaries;
#   baseline_residuals / full_residuals   -- per-row held-out residuals;
#   log_lik_deltas                        -- per-row full-minus-baseline log-likelihood.
baseline_results <- data.frame()
full_model_results <- data.frame()
baseline_residuals <- data.frame()
full_residuals <- data.frame()
log_lik_deltas <- data.frame()

# Randomly shuffle the rows so that fold membership is random.
all_data <- all_data[sample(nrow(all_data)), ]
# Assign the shuffled rows to K (nearly) equally sized, contiguous folds.
# seq_len() (rather than seq(1, n)) stays correct when all_data is empty.
K <- 5
folds <- cut(seq_len(nrow(all_data)), breaks = K, labels = FALSE)
# The loop below performs K-fold cross-validation (K = 5).

# Fit models for some fold of the data.
# Fit the baseline (no-surprisal) regression for one training group, choosing
# the Dundee ("rt") formula or the formula used for all other corpora ("sprt").
# Only the first element of `corpus` is inspected, matching full_model_corpus();
# a bare `if (corpus == ...)` errors in R >= 4.2 when `corpus` has length > 1.
baseline_corpus <- function(corpus, df, test_data, fold, env) {
  if (corpus[1] == "dundee") {
    get_lm_data(df, test_data, baseline_rt_regression, fold, env)
  } else {
    get_lm_data(df, test_data, baseline_sprt_regression, fold, env)
  }
}
# Fit the full (surprisal-augmented) regression for one training group: the
# Dundee ("rt") formula when the group's first corpus value is "dundee",
# otherwise the formula used for all other corpora ("sprt").
full_model_corpus <- function(corpus, df, test_data, fold, env) {
  uses_rt_formula <- corpus[1] == "dundee"
  chosen_formula <- if (uses_rt_formula) full_rt_regression else full_sprt_regression
  get_lm_data(df, test_data, chosen_formula, fold, env)
}

# Separate environments holding the fitted baseline and full regressions,
# each keyed by "<model> <training> <seed> <corpus> <fold>", so residuals and
# log-likelihood deltas can be computed after fitting.
baseline_env <- new.env()
full_env <- new.env()

for (i in seq_len(K)) {
  # Hold out fold i for testing and train on the remaining folds.
  testIndexes <- which(folds == i, arr.ind = TRUE)
  test_data <- all_data[testIndexes, ]
  train_data <- all_data[-testIndexes, ]

  # Compute a baseline linear model for each model--training--seed--RT-corpus combination.
  # (A stray debug `print(model)` used to sit in this pipe and dumped the whole
  # grouped training set to the console every fold.)
  baselines <- train_data %>%
    group_by(model, training, seed, corpus) %>%
      do(baseline_corpus(unique(.$corpus), ., test_data, i, baseline_env)) %>%
    ungroup() %>%
    mutate(seed = as.factor(seed),
           fold = i)

  baseline_results <- rbind(baseline_results, baselines)

  # Compute a full linear model for each model--training--seed-RT-corpus combination
  full_models <- train_data %>%
    group_by(model, training, seed, corpus) %>%
      do(full_model_corpus(unique(.$corpus), ., test_data, i, full_env)) %>%
    ungroup() %>%
    mutate(seed = as.factor(seed),
           fold = i)

  full_model_results <- rbind(full_model_results, full_models)

  # Compute per-row delta-log-likelihoods (full minus baseline) on the held-out fold.
  fold_log_lik_deltas <- test_data %>%
    group_by(model, training, seed, corpus) %>%
      do(get_lm_delta_log_lik(., i, baseline_env, full_env)) %>%
    ungroup()

  log_lik_deltas <- rbind(log_lik_deltas, fold_log_lik_deltas)

  # Held-out residuals and per-row log-likelihoods under each model.
  fold_baseline_residuals <- test_data %>%
    group_by(model, training, seed, corpus) %>%
      do(get_lm_residuals(., i, baseline_env)) %>%
    ungroup()

  baseline_residuals <- rbind(baseline_residuals, fold_baseline_residuals)

  fold_full_residuals <- test_data %>%
    group_by(model, training, seed, corpus) %>%
      do(get_lm_residuals(., i, full_env)) %>%
    ungroup()

  full_residuals <- rbind(full_residuals, fold_full_residuals)
}

|=========================================================================================================================================                                           | 76% ~1 s remaining     
|===========================================================================================================================================                                         | 77% ~1 s remaining     
|=============================================================================================================================================                                       | 78% ~1 s remaining     
|=================================================================================================================================================                                   | 81% ~1 s remaining     
|=====================================================================================================================================================                               | 83% ~0 s remaining     
|=========================================================================================================================================================                           | 85% ~0 s remaining     
|===============================================================================================================================================================                     | 89% ~0 s remaining     
|===================================================================================================================================================================                 | 91% ~0 s remaining     
|=======================================================================================================================================================================             | 93% ~0 s remaining     
|===========================================================================================================================================================================         | 95% ~0 s remaining     
|===============================================================================================================================================================================     | 98% ~0 s remaining     
|=================================================================================================================================================================================   | 99% ~0 s remaining     
|===========================================================================================================================================                                         | 77% ~1 s remaining     
|=============================================================================================================================================                                       | 78% ~1 s remaining     
|=================================================================================================================================================                                   | 81% ~1 s remaining     
|===================================================================================================================================================                                 | 82% ~0 s remaining     
|=======================================================================================================================================================                             | 84% ~0 s remaining     
|=========================================================================================================================================================                           | 85% ~0 s remaining     
|=============================================================================================================================================================                       | 88% ~0 s remaining     
|===============================================================================================================================================================                     | 89% ~0 s remaining     
|=====================================================================================================================================================================               | 92% ~0 s remaining     
|=========================================================================================================================================================================           | 94% ~0 s remaining     
|===========================================================================================================================================================================         | 95% ~0 s remaining     
|===============================================================================================================================================================================     | 98% ~0 s remaining     
|=================================================================================================================================================================================   | 99% ~0 s remaining     

|=========================================================================================================================================================                           | 85% ~0 s remaining     
|===========================================================================================================================================================                         | 86% ~0 s remaining     
|===============================================================================================================================================================                     | 89% ~0 s remaining     
|=====================================================================================================================================================================               | 92% ~0 s remaining     
|===========================================================================================================================================================================         | 95% ~0 s remaining     
|===============================================================================================================================================================================     | 98% ~0 s remaining     
|====================================================================================================================================================================================|100% ~0 s remaining     
|====================================================================================================================================                                                | 74% ~1 s remaining     
|===========================================================================================================================================                                         | 77% ~1 s remaining     
|=============================================================================================================================================                                       | 78% ~1 s remaining     
|===================================================================================================================================================                                 | 82% ~0 s remaining     
|=====================================================================================================================================================                               | 83% ~1 s remaining     
|=========================================================================================================================================================                           | 85% ~0 s remaining     
|===============================================================================================================================================================                     | 89% ~0 s remaining     
|===================================================================================================================================================================                 | 91% ~0 s remaining     
|=====================================================================================================================================================================               | 92% ~0 s remaining     
|===========================================================================================================================================================================         | 95% ~0 s remaining     
|=================================================================================================================================================================================   | 99% ~0 s remaining     

|=======================================================================================================================================================                             | 84% ~0 s remaining     
|=========================================================================================================================================================                           | 85% ~0 s remaining     
|===============================================================================================================================================================                     | 89% ~0 s remaining     
|===================================================================================================================================================================                 | 91% ~0 s remaining     
|=====================================================================================================================================================================               | 92% ~0 s remaining     
|=========================================================================================================================================================================           | 94% ~0 s remaining     
|===========================================================================================================================================================================         | 95% ~0 s remaining     
|===============================================================================================================================================================================     | 98% ~0 s remaining     
|====================================================================================================================================================================================|100% ~0 s remaining     
|===========================================================================================================================================                                         | 77% ~1 s remaining     
|=============================================================================================================================================                                       | 78% ~1 s remaining     
|===================================================================================================================================================                                 | 82% ~0 s remaining     
|=======================================================================================================================================================                             | 84% ~0 s remaining     
|=========================================================================================================================================================                           | 85% ~0 s remaining     
|===========================================================================================================================================================                         | 86% ~0 s remaining     
|===============================================================================================================================================================                     | 89% ~0 s remaining     
|===================================================================================================================================================================                 | 91% ~0 s remaining     
|=====================================================================================================================================================================               | 92% ~0 s remaining     
|=========================================================================================================================================================================           | 94% ~0 s remaining     
|===========================================================================================================================================================================         | 95% ~0 s remaining     
|===============================================================================================================================================================================     | 98% ~0 s remaining     
|=================================================================================================================================================================================   | 99% ~0 s remaining     

|=========================================================================================================================================================                           | 85% ~0 s remaining     
|===============================================================================================================================================================                     | 89% ~0 s remaining     
|===================================================================================================================================================================                 | 91% ~0 s remaining     
|=======================================================================================================================================================================             | 93% ~0 s remaining     
|===========================================================================================================================================================================         | 95% ~0 s remaining     
|===============================================================================================================================================================================     | 98% ~0 s remaining     
|=================================================================================================================================================================================   | 99% ~0 s remaining     
|===========================================================================================================================================                                         | 77% ~1 s remaining     
|===============================================================================================================================================                                     | 80% ~1 s remaining     
|===================================================================================================================================================                                 | 82% ~0 s remaining     
|=========================================================================================================================================================                           | 85% ~0 s remaining     
|===============================================================================================================================================================                     | 89% ~0 s remaining     
|=====================================================================================================================================================================               | 92% ~0 s remaining     
|===========================================================================================================================================================================         | 95% ~0 s remaining     
|===============================================================================================================================================================================     | 98% ~0 s remaining     
|====================================================================================================================================================================================|100% ~0 s remaining     

|=======================================================================================================================================================                             | 84% ~0 s remaining     
|===========================================================================================================================================================                         | 86% ~0 s remaining     
|===============================================================================================================================================================                     | 89% ~0 s remaining     
|=====================================================================================================================================================================               | 92% ~0 s remaining     
|===========================================================================================================================================================================         | 95% ~0 s remaining     
|=================================================================================================================================================================================   | 99% ~0 s remaining     
|===================================================================================================================================================                                 | 82% ~0 s remaining     
|=========================================================================================================================================================                           | 85% ~0 s remaining     
|===============================================================================================================================================================                     | 89% ~0 s remaining     
|=====================================================================================================================================================================               | 92% ~0 s remaining     
|===========================================================================================================================================================================         | 95% ~0 s remaining     
|=================================================================================================================================================================================   | 99% ~0 s remaining     
#write.csv(full_residuals, "../data/analysis_checkpoints/full_residuals.csv")
#write.csv(baseline_residuals, "../data/analysis_checkpoints/baseline_residuals.csv")
# For each model--training--seed--corpus group: mean per-row delta
# log-likelihood (full minus baseline) pooled over all held-out folds, and its
# standard error of the mean.
model_deltas <- log_lik_deltas %>%
  group_by(model, training, seed, corpus) %>%
  summarise(mean_delta_log_lik = mean(delta_log_lik),
            sem_delta_log_lik = sd(delta_log_lik) / sqrt(n()))
# Checkpoint the per-fold regression summaries.
write.csv(x = full_model_results, file = "../data/analysis_checkpoints/full_model_result.csv")
write.csv(x = baseline_results, file = "../data/analysis_checkpoints/baseline_results.csv")
# To reload from a checkpoint instead of refitting, uncomment below.
# NOTE(review): these reload paths ("ffull_model_results.csv",
# "fbaseline_resultsb.csv") do not match the filenames written above --
# confirm which files are intended before relying on them.
#full_model_results = read.csv("../data/analysis_checkpoints/ffull_model_results.csv")
#baseline_results = read.csv("../data/analysis_checkpoints/fbaseline_resultsb.csv")
# Axis label for the evaluation metric currently in use.
metric <- "ΔLogLik"
#metric <- "-ΔMSE"

# Expose the chosen metric under generic column names (delta_test_mean,
# delta_test_sem) so the plots and regressions below are metric-agnostic.
# For MSE, rename mean_delta_mse / sem_delta_mse here instead.
model_deltas <- model_deltas %>%
  rename(delta_test_mean = mean_delta_log_lik,
         delta_test_sem = sem_delta_log_lik)
model_deltas
# Sanity check: training on train+test data should yield improved performance over training on just training data. (When evaluating on test data.)
# full_baselines = all_data %>%
#   group_by(model, training, seed, corpus) %>%
#   summarise(baseline_train_all_test_lik = logLik_test(lm(psychometric ~ len + freq + sent_pos, data=.), semi_join(test_data, ., by=c("training", "model", "seed", "corpus")), semi_join(test_data, ., by=c("training", "model", "seed", "corpus"))$psychometric)) %>%
#   ungroup()
# full_baselines
# 
# full_baselines %>%
#   right_join(baselines, by=c("seed", "training", "model", "corpus")) %>%
#   mutate(delta=baseline_train_all_test_lik-baseline_test_lik) %>%
#   select(-baseline_lik) # %>%
#   #select(-baseline_test_lik, -baseline_train_all_test_lik, -baseline_lik, -baseline_test_mse)

Load language model data (SyntaxGym, PPL)

# Per-language-model metadata, harmonized so it can be joined to model_deltas
# on seed / training / model.
language_model_data <- read.csv("../data/model_metadata.csv") %>%
  # Spell GPT-2 the same way as the surprisal data does.
  mutate(model = as.character(model),
         model = as.factor(if_else(model == "gpt-2", "gpt2", model))) %>%
  mutate(
    # Training-set size in millions of tokens, keyed on the BLLIP subset name.
    train_size = case_when(
      str_starts(training, "bllip-lg") ~ 42,
      str_starts(training, "bllip-md") ~ 15,
      str_starts(training, "bllip-sm") ~ 5,
      str_starts(training, "bllip-xs") ~ 1
    ),
    # Training vocabulary usually covaries with the training corpus.
    # But BPE models share a vocabulary across training corpora.
    training_vocab = as.factor(ifelse(str_detect(training, "gptbpe"),
                                      "gptbpe",
                                      as.character(training))),
    training_source = as.factor(str_replace(as.character(training), "-gptbpe", ""))
  ) %>%
  mutate(seed = as.factor(seed)) %>%
  select(-pid, -test_loss) %>%
  distinct(model, training, seed, .keep_all = TRUE)
table(language_model_data$seed)

         0        111        120        922       1111       3602       4301       7245       7877      28066      28068      44862      51272      64924 1581807512 1581807578 1581861474 1581955288 
         4          7          6          5          4          1          1          1          1          1          1          1          1          1          1          1          1          1 
1582126320 1586986276 1587139950 
         1          1          1 
# Seed counts in the delta-metric summaries, for comparison with the metadata
# seeds above; seeds with no metadata row are removed by the join below.
table(model_deltas$seed)

       111        120        607        922       1111       3602       4301       7245       7877      28066      28068      44862      51272      64924 1581807512 1581807578 1581861474 1581955288 
         9          9          1          9         12          3          3          3          3          3          3          3          3          3          3          3          3          3 
1582126320 1586986276 1587139950 
         3          3          3 

First join delta-metric data with model auxiliary data.

# Attach each language model's metadata (SG score, perplexity, training size,
# ...) to its delta-metric summary. merge(all = TRUE) is a full outer join;
# drop_na() then discards rows found on only one side (e.g. seeds with no
# metadata), as well as any row with a missing value in any column.
model_deltas <- model_deltas %>%
  merge(language_model_data, by = c("seed", "training", "model"), all = TRUE) %>%
  drop_na()

model_deltas

Also join on the original linear model data, rather than collapsing to delta-metrics. This will support regressions later on that don’t collapse across folds.

Final data preprocessing

# Exclude ordered-neurons from all analyses.
# Ordered-neurons models are excluded from every analysis below.
model_deltas <- filter(model_deltas, model != "ordered-neurons")

Visualizations

The basics

# Number of rows per corpus, as a bar chart and as a table.
ggplot(all_data, aes(x = corpus)) + geom_bar()

print(all_data %>% count(corpus))

# Distributions of the main predictors, overlaid by corpus.
ggplot(all_data, aes(x = freq, color = corpus)) + geom_density()

ggplot(all_data, aes(x = len, color = corpus)) + geom_density()

ggplot(all_data, aes(x = surprisal, color = corpus)) + geom_density()

Predictive power and SG

# Delta metric against SyntaxGym score, one panel per corpus, with an OLS trend line.
ggplot(model_deltas, aes(x = sg_score, y = delta_test_mean)) +
  geom_errorbar(aes(ymin = delta_test_mean - delta_test_sem,
                    ymax = delta_test_mean + delta_test_sem)) +
  geom_smooth(method = "lm", se = TRUE) +
  geom_point(aes(color = training_vocab, shape = model),
             stat = "identity", position = "dodge", alpha = 1, size = 3) +
  ylab(metric) +
  xlab("Syntax Generalization Score") +
  ggtitle("Syntactic Generalization vs. Predictive Power") +
  scale_color_manual(values = c("bllip-lg" = "#440154FF",
                                "bllip-md" = "#39568CFF",
                                "bllip-sm" = "#1F968BFF",
                                "bllip-xs" = "#73D055FF",
                                "gptbpe" = "#888888")) +
  facet_grid(~corpus, scales = "free") +
  theme(axis.text = element_text(size = 14),
        strip.text.x = element_text(size = 14),
        legend.text = element_text(size = 14),
        axis.title = element_text(size = 18),
        legend.position = "bottom")

#ggsave("./cogsci_images/sg_loglik.png",height=5,width=6)

Regression analyses

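We control for effects of perplexity by regressing both predictive power and SG score on perplexity, and then relating the residuals of the performance ~ PPL regression to the residuals of the SG score ~ PPL regression.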

# Residualize both the delta metric and the SG score on test perplexity, then
# summarise the residuals per model--training vocabulary--corpus--seed.
d_resid <- model_deltas %>%
  drop_na() %>%
  # Delta metric ~ per-training perplexity slopes, fit separately within each corpus.
  group_by(corpus) %>%
  mutate(resid.delta = resid(lm(delta_test_mean ~ training:test_ppl))) %>%
  ungroup() %>%
  # SG score ~ perplexity, fit separately within each training vocabulary
  # (no training:ppl interaction needed, since each fit is within-group).
  group_by(training_vocab) %>%
  mutate(resid.sg = resid(lm(sg_score ~ test_ppl))) %>%
  ungroup() %>%
  group_by(model, training_vocab, corpus, seed) %>%
  summarise(
    resid.delta.mean = mean(resid.delta),
    resid.delta.sem = sd(resid.delta) / sqrt(n()),
    resid.sg.mean = mean(resid.sg),
    resid.sg.sem = sd(resid.sg) / sqrt(n())
  )
# Plot the perplexity-residualized delta metric against the
# perplexity-residualized SG score, one panel per corpus.
# Vertical (delta) and horizontal (SG) error bars need separate layers:
# geom_errorbar() ignores xmin/xmax ("Ignoring unknown aesthetics"), so the
# horizontal bars are drawn with geom_errorbarh(), which takes y from the plot aes.
d_resid %>%
  #filter(corpus != "bnc-brown") %>%
  ggplot(aes(x=resid.sg.mean, y=resid.delta.mean)) +
    geom_errorbar(aes(ymin=resid.delta.mean - resid.delta.sem,
                      ymax=resid.delta.mean + resid.delta.sem), alpha=0.3) +
    geom_errorbarh(aes(xmin=resid.sg.mean - resid.sg.sem,
                       xmax=resid.sg.mean + resid.sg.sem), alpha=0.3) +
    geom_smooth(method="lm", se=TRUE) +
    geom_point(stat="identity", position="dodge", alpha=1, size=4, aes(shape=model, color=training_vocab)) +
    ylab(paste("Residual", metric)) +
    xlab("Residual Syntax Generalization Score") +
    ggtitle("Syntactic Generalization vs. Predictive Power") +
    scale_color_manual(values = c("bllip-lg"="#440154FF",
                                  "bllip-md"="#39568CFF",
                                  "bllip-sm"="#1F968BFF",
                                  "bllip-xs"="#73D055FF",
                                  "gptbpe"="#888888")) +
    facet_grid(.~corpus, scales="free") +
    theme(axis.text=element_text(size=14),
          strip.text.x = element_text(size=14),
          legend.text=element_text(size=14),
          axis.title=element_text(size=18),
          legend.position = "right")
ggsave("../images/cuny2020/ppl_sg.png",height=4.5,width=11)

# Nested-model test for one corpus: does SG score explain variance in the delta
# metric beyond per-training-vocabulary perplexity slopes?
# Prints a header and the ANOVA comparing the two fits, then returns summary()
# of the larger model (auto-printed when called at top level).
do_stepwise_regression <- function(cur_corpus) {
  regression_data <- filter(model_deltas, corpus == cur_corpus)

  print("----------------------")
  print(cur_corpus)

  ppl_fit <- lm(delta_test_mean ~ training_vocab:test_ppl, data = regression_data)
  ppl_sg_fit <- lm(delta_test_mean ~ training_vocab:test_ppl + sg_score, data = regression_data)
  print(anova(ppl_fit, ppl_sg_fit))
  summary(ppl_sg_fit)
}
# Run the SG-beyond-perplexity nested-model comparison for each corpus.
# (The bnc-brown call is currently commented out.)
#do_stepwise_regression("bnc-brown")
do_stepwise_regression("dundee")
[1] "----------------------"
[1] "dundee"
Analysis of Variance Table

Model 1: delta_test_mean ~ training_vocab:test_ppl
Model 2: delta_test_mean ~ training_vocab:test_ppl + sg_score
  Res.Df        RSS Df  Sum of Sq      F Pr(>F)
1     23 6.3436e-05                            
2     22 6.2294e-05  1 1.1421e-06 0.4034 0.5319

Call:
lm(formula = delta_test_mean ~ training_vocab:test_ppl + sg_score, 
    data = regression_data)

Residuals:
       Min         1Q     Median         3Q        Max 
-0.0022638 -0.0009928 -0.0004429  0.0008255  0.0033920 

Coefficients:
                                  Estimate Std. Error t value Pr(>|t|)    
(Intercept)                      8.119e-03  1.848e-03   4.393 0.000231 ***
sg_score                        -1.672e-03  2.632e-03  -0.635 0.531908    
training_vocabbllip-lg:test_ppl -1.742e-05  1.712e-05  -1.017 0.320187    
training_vocabbllip-md:test_ppl -2.018e-05  1.427e-05  -1.414 0.171283    
training_vocabbllip-sm:test_ppl -2.708e-05  1.271e-05  -2.131 0.044507 *  
training_vocabbllip-xs:test_ppl -2.602e-05  8.479e-06  -3.069 0.005621 ** 
training_vocabgptbpe:test_ppl   -6.339e-06  4.166e-06  -1.522 0.142331    
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 0.001683 on 22 degrees of freedom
Multiple R-squared:  0.3827,    Adjusted R-squared:  0.2144 
F-statistic: 2.274 on 6 and 22 DF,  p-value: 0.07374
# Same nested-model comparison for the natural-stories corpus.
do_stepwise_regression("natural-stories")
[1] "----------------------"
[1] "natural-stories"
Analysis of Variance Table

Model 1: delta_test_mean ~ training_vocab:test_ppl
Model 2: delta_test_mean ~ training_vocab:test_ppl + sg_score
  Res.Df        RSS Df  Sum of Sq      F  Pr(>F)  
1     23 1.8197e-05                               
2     22 1.5826e-05  1 2.3704e-06 3.2951 0.08314 .
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Call:
lm(formula = delta_test_mean ~ training_vocab:test_ppl + sg_score, 
    data = regression_data)

Residuals:
       Min         1Q     Median         3Q        Max 
-0.0012043 -0.0004901 -0.0001336  0.0004903  0.0014389 

Coefficients:
                                  Estimate Std. Error t value Pr(>|t|)   
(Intercept)                      3.326e-03  9.316e-04   3.570  0.00171 **
sg_score                         2.408e-03  1.327e-03   1.815  0.08314 . 
training_vocabbllip-lg:test_ppl -1.498e-05  8.631e-06  -1.735  0.09670 . 
training_vocabbllip-md:test_ppl -1.951e-05  7.193e-06  -2.713  0.01271 * 
training_vocabbllip-sm:test_ppl -2.100e-05  6.406e-06  -3.279  0.00343 **
training_vocabbllip-xs:test_ppl -1.340e-05  4.274e-06  -3.136  0.00481 **
training_vocabgptbpe:test_ppl   -2.854e-06  2.100e-06  -1.359  0.18787   
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 0.0008482 on 22 degrees of freedom
Multiple R-squared:  0.6531,    Adjusted R-squared:  0.5585 
F-statistic: 6.904 on 6 and 22 DF,  p-value: 0.0003178

Predictive power and perplexity

# Delta metric against test perplexity, coloured by training vocabulary.
ggplot(model_deltas, aes(x = test_ppl, y = delta_test_mean, color = training_vocab)) +
  geom_errorbar(aes(ymin = delta_test_mean - delta_test_sem,
                    ymax = delta_test_mean + delta_test_sem),
                alpha = 0.4) +
  #geom_smooth(method="lm", se=F) +
  geom_point(aes(shape = model),
             stat = "identity", position = "dodge", alpha = 1, size = 4) +
  ylab(metric) +
  xlab("Test Perplexity") +
  #coord_cartesian(ylim = c(1, 16)) +
  ggtitle("Test Perplexity vs. Predictive Power") +
  scale_color_manual(values = c("bllip-lg" = "#440154FF",
                                "bllip-md" = "#39568CFF",
                                "bllip-sm" = "#1F968BFF",
                                "bllip-xs" = "#73D055FF",
                                "gptbpe" = "#888888")) +
  facet_grid(~corpus, scales = "free") +
  #coord_cartesian(ylim = c(0, 150)) +
  theme(axis.text = element_text(size = 12),
        strip.text.x = element_text(size = 12),
        legend.text = element_text(size = 12),
        axis.title = element_text(size = 12),
        legend.position = "right")
ggsave("../images/cuny2020/ppl_loglik.png", height = 4.5, width = 11)

Effect of training data size

# Delta metric against log training-set size, one panel per corpus x model.
model_deltas %>%
  mutate(train_size = log(train_size)) %>%
  ggplot(aes(x = train_size, y = delta_test_mean, color = model)) +
  geom_errorbar(aes(ymin = delta_test_mean - delta_test_sem,
                    ymax = delta_test_mean + delta_test_sem),
                width = 0.1) +
  geom_smooth(method = "lm", se = TRUE, alpha = 0.5) +
  geom_point(stat = "identity", position = "dodge", alpha = 1, size = 3) +
  ylab(metric) +
  xlab("Log Million Training Tokens") +
  ggtitle("Training Size vs. Predictive Power") +
  facet_grid(corpus ~ model, scales = "free") +
  #scale_color_manual(values = c("#A42EF1", "#3894C8")) +
  theme(axis.text = element_text(size = 14),
        strip.text.x = element_text(size = 14),
        legend.text = element_text(size = 14),
        axis.title = element_text(size = 18),
        legend.position = "bottom")

#ggsave("./cogsci_images/training_loglik.png",height=5,width=6)
# SG score against log training-set size, one panel per model.
# model_deltas has one row per model--training--seed--corpus, but sg_score comes
# from the per-model metadata (joined on seed/training/model), so each score is
# repeated once per corpus. Keep one row per language model; otherwise every
# point is overplotted and the lm smooth's confidence band is computed from
# duplicated rows, which makes it too narrow.
model_deltas %>%
  distinct(model, training, seed, .keep_all = TRUE) %>%
  mutate(train_size = log(train_size)) %>%
  ggplot(aes(x=train_size, y=sg_score, color=model)) +
    geom_smooth(method="lm", se=TRUE, alpha=0.5) +
    geom_point(stat="identity", position="dodge", alpha=1, size=3) +
    ylab("SG Score") +
    xlab("Log Million Training Tokens") +
    ggtitle("Training Size vs. Syntactic Generalization") +
    #scale_color_manual(values = c("#A42EF1", "#3894C8")) +
    facet_grid(~model, scales="free") +
    theme(axis.text=element_text(size=14),
          strip.text.x = element_text(size=14),
          legend.text=element_text(size=14),
          axis.title=element_text(size=18),
          legend.position = "bottom")

#ggsave("./cogsci_images/training_sg.png",height=5,width=6)

Smith & Levy reproduction

# Smith & Levy-style reproduction: smoothed processing time against surprisal,
# one panel per corpus x model. Colour is the training corpus with any
# "-gptbpe" suffix stripped; linetype is whether the training name contains "bpe".
all_data %>%
  #filter(surprisal < 15, surprisal > 0) %>%
  mutate(bpe = str_detect(training, "bpe"),
         training_source = str_replace(training, "-gptbpe", "")) %>%
  ggplot(aes(x = surprisal, y = psychometric,
             color = training_source, linetype = bpe)) +
  stat_smooth(se = TRUE, alpha = 0.5) +
  #geom_errorbar(color="black", width=.2, position=position_dodge(width=.9), alpha=0.3) +
  #geom_point(stat="identity", position="dodge", alpha=1, size=3) +
  ylab("Processing Time (ms)") +
  xlab("Surprisal (bits)") +
  ggtitle("Surprisal vs. Reading Time / Gaze Duration") +
  facet_grid(corpus ~ model, scales = "free") +
  # scale_color_manual(values = c("bllip-lg"="#440154FF",
  #                           "bllip-md"="#39568CFF",
  #                           "bllip-sm"="#1F968BFF",
  #                           "bllip-xs"="#73D055FF",
  #                           "bllip-lg-gptbpe"="#888888",
  #                           "bllip-md-gptbpe"="#888888",
  #                           "bllip-sm-gptbpe"="#888888",
  #                           "bllip-xs-gptbpe"="#888888")) +
  theme(axis.text = element_text(size = 14),
        axis.text.y = element_text(size = 10),
        strip.text.x = element_text(size = 14),
        legend.text = element_text(size = 14),
        axis.title = element_text(size = 18),
        legend.position = "right")
ggsave("../images/cuny2020/surp_corr.png", height = 4.5, width = 12)

Investigate vanilla

# Raw processing time against surprisal for the vanilla model, one panel per
# corpus x training set.
all_data %>%
  #filter(surprisal < 15, surprisal > 0) %>%
  filter(model == "vanilla") %>%
  ggplot(aes(x = surprisal, y = psychometric)) +
  #stat_smooth(se=T, alpha=0.5) +
  #geom_errorbar(color="black", width=.2, position=position_dodge(width=.9), alpha=0.3) +
  geom_point(alpha = 0.1) +
  ylab("Processing Time (ms)") +
  xlab("Surprisal (bits)") +
  ggtitle("Surprisal vs. Reading Time / Gaze Duration: Vanilla") +
  facet_grid(corpus ~ training, scales = "free")

  # scale_color_manual(values = c("bllip-lg"="#440154FF",
  #                           "bllip-md"="#39568CFF",
  #                           "bllip-sm"="#1F968BFF",
  #                           "bllip-xs"="#73D055FF",
  #                           "bllip-lg-gptbpe"="#888888",
  #                           "bllip-md-gptbpe"="#888888",
  #                           "bllip-sm-gptbpe"="#888888",
  #                           "bllip-xs-gptbpe"="#888888"))
# High-surprisal, fast tokens for dundee / vanilla / bllip-lg.
# NB: all_data is restricted to surprisal < 15 when loaded, so this
# surprisal > 20 query matches no rows unless that filter is removed.
filter(all_data, corpus == "dundee", model == "vanilla", training == "bllip-lg",
       surprisal > 20, psychometric < 300)
# Residuals of the full model for the same subset, largest first; then the
# surprisal distribution of tokens with residual > 150.
full_residuals %>%
  filter(corpus == "dundee", model == "vanilla", training == "bllip-lg") %>%
  arrange(desc(resid)) %>%
  print()
full_residuals %>%
  filter(corpus == "dundee", model == "vanilla", training == "bllip-lg", resid > 150) %>%
  arrange(desc(resid)) %>%
  ggplot(aes(x = surprisal)) + geom_density()

Investigate RNNG

# Raw processing time against surprisal for the RNNG model, one panel per
# corpus x training set.
all_data %>%
  #filter(surprisal < 15, surprisal > 0) %>%
  filter(model == "rnng") %>%
  ggplot(aes(x = surprisal, y = psychometric)) +
  #stat_smooth(se=T, alpha=0.5) +
  #geom_errorbar(color="black", width=.2, position=position_dodge(width=.9), alpha=0.3) +
  geom_point(alpha = 0.1) +
  ylab("Processing Time (ms)") +
  xlab("Surprisal (bits)") +
  ggtitle("Surprisal vs. Reading Time / Gaze Duration: RNNG") +
  facet_grid(corpus ~ training, scales = "free")

# High-surprisal, fast tokens for dundee / rnng / bllip-lg.
# NB: all_data is restricted to surprisal < 15 when loaded, so this
# surprisal > 20 query matches no rows unless that filter is removed.
filter(all_data, corpus == "dundee", model == "rnng", training == "bllip-lg",
       surprisal > 20, psychometric < 300)
# Residuals of the full model for the same subset, largest first; then the
# surprisal distribution of tokens with residual > 150.
full_residuals %>%
  filter(corpus == "dundee", model == "rnng", training == "bllip-lg") %>%
  arrange(desc(resid)) %>%
  print()
full_residuals %>%
  filter(corpus == "dundee", model == "rnng", training == "bllip-lg", resid > 150) %>%
  arrange(desc(resid)) %>%
  ggplot(aes(x = surprisal)) + geom_density()

Investigate ngram vs vanilla

# Per-token (corpus, code) means of frequency, processing time and full-model
# residual, averaged over the remaining rows (seeds) of one model/training setup.
ngram_resids <- full_residuals %>%
  filter(model == "5gram", training == "bllip-sm") %>%
  group_by(corpus, code) %>%
  summarise(freq = mean(freq), psychometric = mean(psychometric), resid = mean(resid))
vanilla_resids <- full_residuals %>%
  filter(model == "vanilla", training == "bllip-sm") %>%
  group_by(corpus, code) %>%
  summarise(freq = mean(freq), psychometric = mean(psychometric), resid = mean(resid))
resids_joined <- left_join(ngram_resids, vanilla_resids,
                           by = c("corpus", "code"),
                           suffix = c(".ngram", ".vanilla"))

# Token-level residual agreement between the two models (red line: y = x).
ggplot(resids_joined, aes(x = resid.ngram, y = resid.vanilla)) +
  geom_point() +
  geom_abline(slope = 1, color = "red") +
  facet_grid(~corpus)


# Distribution of absolute residual differences, per corpus.
resids_joined %>%
  mutate(resid_abs_diff = abs(resid.ngram - resid.vanilla)) %>%
  ggplot(aes(x = resid_abs_diff)) +
  geom_density() +
  facet_grid(~corpus)


# Absolute residual difference as a function of token frequency.
resids_joined %>%
  mutate(resid_abs_diff = abs(resid.ngram - resid.vanilla)) %>%
  ggplot(aes(x = freq.ngram, y = resid_abs_diff)) +
  geom_point(alpha = 0.1) +
  geom_smooth()

Investigate gptbpe vs vanilla

# Per-token (corpus, code) means of frequency, processing time and full-model
# residual, averaged over the remaining rows (seeds) of one model/training setup.
gpt_resids <- full_residuals %>%
  filter(model == "gpt2", training == "bllip-sm-gptbpe") %>%
  group_by(corpus, code) %>%
  summarise(freq = mean(freq), psychometric = mean(psychometric), resid = mean(resid))
vanilla_resids <- full_residuals %>%
  filter(model == "vanilla", training == "bllip-sm") %>%
  group_by(corpus, code) %>%
  summarise(freq = mean(freq), psychometric = mean(psychometric), resid = mean(resid))
resids_joined <- left_join(gpt_resids, vanilla_resids,
                           by = c("corpus", "code"),
                           suffix = c(".gpt", ".vanilla"))

# Token-level residual agreement between the two models (red line: y = x).
ggplot(resids_joined, aes(x = resid.gpt, y = resid.vanilla)) +
  geom_point() +
  geom_abline(slope = 1, color = "red") +
  facet_grid(~corpus)


# Distribution of absolute residual differences, per corpus.
resids_joined %>%
  mutate(resid_abs_diff = abs(resid.gpt - resid.vanilla)) %>%
  ggplot(aes(x = resid_abs_diff)) +
  geom_density() +
  facet_grid(~corpus)


# Absolute residual difference as a function of token frequency.
resids_joined %>%
  mutate(resid_abs_diff = abs(resid.gpt - resid.vanilla)) %>%
  ggplot(aes(x = freq.gpt, y = resid_abs_diff)) +
  geom_point(alpha = 0.1) +
  geom_smooth()

Investigate residuals overall

# Per-row change in residual from the baseline to the full model, plotted
# against surprisal, one panel per model x corpus.
full_residuals %>%
  right_join(baseline_residuals,
             by = c("corpus", "code", "model", "training", "seed"),
             suffix = c(".full", ".baseline")) %>%
  mutate(resid_delta = resid.full - resid.baseline,
         training_source = as.factor(str_replace(training, "-gptbpe", "")),
         bpe = str_detect(training, "gptbpe")) %>%
  ggplot(aes(x = surprisal.full, y = resid_delta, color = training)) +
  facet_grid(model ~ corpus) +
  geom_point(alpha = 0.1, size = 0.5)

# Inspect the metadata rows for GPT-2.
filter(language_model_data, model == "gpt2")
---
title: "CUNY 2020 Analysis"
output: html_notebook
---

# Packages and utilities

```{r}
library(tidyverse)
library(lme4)
library(lmerTest)
library(logging)
library(mvtnorm)
library(mgcv)
```

```{r}
# Held-out log-likelihood of `test_y` under a fitted regression model `lm`.
#
# Each observation is scored under a Gaussian centred on the model's
# prediction for the matching row of `test_X`. The standard deviation is
# sigma(lm), the residual SD estimated from the training data. The
# per-observation log-densities are summed into a single number.
# `re.form = NA` requests population-level (no random-effect) predictions from
# lme4 fits; for plain lm fits it is presumably ignored by predict.lm()'s `...`.
logLik_test <- function(lm, test_X, test_y) {
  mu <- predict(lm, test_X, re.form = NA)
  residual_sd <- sigma(lm)
  log_densities <- dnorm(test_y, mu, residual_sd, log = TRUE)
  sum(log_densities)
}
# Per-observation held-out log-likelihoods of `test_y` under a fitted model:
# log N(test_y[i] | prediction for row i of test_X, sigma(lm)), with the residual
# SD estimated from the training data. Returns one value per observation.
logLik_test_per <- function(lm, test_X, test_y) {
  mu <- predict(lm, test_X, re.form = NA)
  residual_sd <- sigma(lm)
  dnorm(test_y, mu, residual_sd, log = TRUE)
}
# Held-out mean squared error of a fitted model's (population-level)
# predictions for `test_X` against the observed `test_y`.
mse_test <- function(lm, test_X, test_y) {
  errors <- predict(lm, test_X, re.form = NA) - test_y
  mean(errors^2)
}
#Sanity checks
#mylm <- gam(psychometric ~  s(surprisal, bs = "cr", k = 20) + s(prev_surp, bs = "cr", k = 20) + te(freq, len, bs = "cr") + te(prev_freq, prev_len, bs = "cr"), data=train_data)
#c(logLik(mylm), logLik_test(mylm, train_data, train_data$psychometric))
#logLik_test(mylm, test_data, test_data$psychometric)
```

# Data loading and preprocessing

```{r Load and preprocess data}
data <- read.csv("../data/harmonized_results.csv")

# Attach up to three tokens of left context (surprisal, code, length,
# frequency) to every row, within each corpus--model--training--seed group.
# NB lag() depends on row order: this assumes rows within each group are
# already in corpus order — TODO confirm for harmonized_results.csv.
all_data <- data %>%
  mutate(seed = as.factor(seed)) %>%
  group_by(corpus, model, training, seed) %>%
    mutate(prev_surp = lag(surprisal),
           prev_code = lag(code),
           prev_len = lag(len),
           prev_freq = lag(freq),

           prev2_freq = lag(prev_freq),
           prev2_code = lag(prev_code),
           prev2_len = lag(prev_len),
           prev2_surp = lag(prev_surp),

           prev3_freq = lag(prev2_freq),
           prev3_code = lag(prev2_code),
           prev3_len = lag(prev2_len),
           prev3_surp = lag(prev2_surp)) %>%
  ungroup() %>%

  # Require contiguous token history as far back as each corpus's regression
  # looks. Dundee's regressions use prev_*, prev2_* and prev3_* features, so
  # require three contiguous preceding tokens. All other corpora use only
  # prev_* features, so require one.
  # NB this effectively removes all zero-surprisal rows, since early-sentence tokens don't have contiguous token history
  filter((corpus == "dundee" & code == prev3_code + 3) | (corpus != "dundee" & code == prev_code + 1)) %>%

  select(-prev_code, -prev2_code, -prev3_code) %>%
  drop_na()

all_data <- all_data %>%
  mutate(
    model = as.character(model),
    model = if_else(model == "gpt-2", "gpt2", model),
    model = as.factor(model)) %>%
  # DEV: keeps only tokens with surprisal < 15 — confirm whether this
  # development-time restriction is intended for the final analysis.
  filter(surprisal < 15)
```

```{r}
# (corpus, code) x (model, training, seed) combinations that have no row in
# all_data, restricted to tokens missing a surprisal value for at least one
# model configuration.
missing_rows <- all_data %>%
  complete(nesting(corpus, code), nesting(model, training, seed)) %>%
  group_by(corpus, code) %>%
    filter(anyNA(surprisal)) %>%
  ungroup() %>%
  anti_join(all_data, by = c("corpus", "code", "model", "training", "seed"))

missing_rows %>%
  ggplot(aes(x = corpus, fill = factor(paste(model, training)))) +
  geom_bar(position = position_dodge(width = 0.8))
# Number of missing tokens per model--training--seed--corpus, largest first.
# The sort happens before print(), so the printed table is the sorted one
# (and it is printed only once).
missing_rows %>%
  count(model, training, seed, corpus, sort = TRUE) %>%
  print()
```


```{r Drop tokens for which any model is missing surprisal data.}

# Tokens (corpus, code) that have fewer observations than the best-covered
# token in the same corpus, i.e. tokens missing data for some
# model--training--seed. One row per token.
to_drop <- all_data %>%
  count(corpus, code) %>%
  group_by(corpus) %>%
    filter(n != max(n)) %>%
  ungroup() %>%
  select(code, corpus)

#to_drop = all_data %>% group_by(corpus, code) %>% filter(n() != ideal_token_obs_count) %>% ungroup()
# to_drop has one row per token, so count the affected observations in all_data
# separately; nrow(to_drop) is the number of tokens.
loginfo(paste("Dropping", nrow(semi_join(all_data, to_drop, by = c("corpus", "code"))), "observations corresponding to corpus tokens which are missing observations for some model."))
loginfo(paste("Dropping", nrow(to_drop), "tokens which are missing observations for some model."))

all_data <- all_data %>% anti_join(to_drop, by = c("corpus", "code"))
loginfo(paste("After drop,", nrow(all_data), "observations (", all_data %>% group_by(corpus, code) %>% n_groups(), " tokens) remain."))
```

```{r Drop tokens for which any model has zero-valued surprisals.}

# All rows belonging to any token (corpus, code) for which at least one model
# assigns a surprisal of exactly zero.
to_drop_zero_surps <- all_data %>%
  group_by(corpus, code) %>%
    filter(any(surprisal == 0)) %>%
  ungroup()
loginfo(paste("Dropping", nrow(to_drop_zero_surps), "observations corresponding to corpus tokens which have surprisal zeros for some model."))
loginfo(paste("Dropping", n_groups(group_by(to_drop_zero_surps, corpus, code)), "tokens which have surprisal zeros for some model."))

all_data <- anti_join(all_data, to_drop_zero_surps, by = c("corpus", "code"))
loginfo(paste("After drop,", nrow(all_data), "observations (", n_groups(group_by(all_data, corpus, code)), " tokens) remain."))
```

```{r Drop tokens for which we have zero-valued psychometric data.}

# All rows belonging to any token (corpus, code) that has a psychometric value
# of exactly zero in at least one row.
to_drop_zero_psychs <- all_data %>%
  group_by(corpus, code) %>%
    filter(any(psychometric == 0)) %>%
  ungroup()
loginfo(paste("Dropping", nrow(to_drop_zero_psychs), "observations corresponding to corpus tokens which have zero-valued psychometric data."))
loginfo(paste("Dropping", to_drop_zero_psychs %>% group_by(corpus, code) %>% n_groups(), "tokens which have zero-valued psychometric data."))

all_data <- all_data %>% anti_join(to_drop_zero_psychs, by = c("corpus", "code"))
loginfo(paste("After drop,", nrow(all_data), "observations (", all_data %>% group_by(corpus, code) %>% n_groups(), " tokens) remain."))
```

 
# Learn models
 
```{r}
# Fit `formula` by OLS on one model--training--seed--corpus training subset and
# evaluate the fit on the matching rows of the held-out data.
#
# df:        training rows for a single model--training--seed--corpus group.
# test_data: the full held-out fold; it is subset here (semi_join) to df's group.
# formula:   regression formula passed to lm().
# fold:      fold index, appended to the name under which the fit is stored.
# store_env: environment in which the fit is assigned, under the name
#            "<model> <training> <seed> <corpus> <fold>", so that the residual
#            and delta-log-likelihood helpers can retrieve it later.
# Returns summarise(df, ...) with columns log_lik (in-sample log-likelihood),
# test_lik (held-out log-likelihood) and test_mse (held-out MSE).
get_lm_data <- function(df, test_data, formula, fold, store_env) {
  #this_lm <- gam(formula, data=df)
  this_lm <- lm(formula, data = df)
  this_test_data <- semi_join(test_data, df, by = c("training", "model", "seed", "corpus"))

  # Store the fit in store_env (not the global env) for later residual queries.
  lm_name <- paste(unique(paste(df$model, df$training, df$seed, df$corpus))[1], fold)
  assign(lm_name, this_lm, envir = store_env)

  summarise(df,
            log_lik = as.numeric(logLik(this_lm, REML = FALSE)),
            test_lik = logLik_test(this_lm, this_test_data, this_test_data$psychometric),
            test_mse = mse_test(this_lm, this_test_data, this_test_data$psychometric))
}
# Add per-row held-out log-likelihood (`likelihood`) and residual (`resid`,
# observed minus predicted psychometric) columns to `df`. Uses the fit
# previously stored in `store_env` for df's model--training--seed--corpus
# group and this `fold`.
get_lm_residuals <- function(df, fold, store_env) {
  group_key <- unique(paste(df$model, df$training, df$seed, df$corpus))[1]
  fitted_lm <- get(paste(group_key, fold), envir = store_env)

  predicted <- predict(fitted_lm, df, re.form = NA)
  mutate(df,
         likelihood = logLik_test_per(fitted_lm, df, df$psychometric),
         resid = df$psychometric - predicted)
}
# Per-row delta-log-likelihood on a held-out fold: the log-likelihood under the
# full model minus that under the baseline model. Both fits are looked up by
# the same "<model> <training> <seed> <corpus> <fold>" name in their respective
# environments. Returns test_data with a delta_log_lik column bound on.
get_lm_delta_log_lik <- function(test_data, fold, baseline_env, full_env) {
  key <- paste(unique(paste(test_data$model, test_data$training, test_data$seed, test_data$corpus))[1], fold)
  baseline_fit <- get(key, envir = baseline_env)
  full_fit <- get(key, envir = full_env)

  full_lik <- logLik_test_per(full_fit, test_data, test_data$psychometric)
  baseline_lik <- logLik_test_per(baseline_fit, test_data, test_data$psychometric)
  cbind(test_data, delta_log_lik = full_lik - baseline_lik)
}
#####
# Define regression formulae.
# RT formulae (used for dundee): the current token plus three tokens of left
# context (prev_*, prev2_*, prev3_*).
# SPRT formulae (used for every other corpus): the current token plus one token
# of left context (prev_*).
# Baselines use only frequency and length; full models add surprisal terms.
# Earlier GAM versions, kept for reference:
#baseline_rt_regression = psychometric ~ te(freq, len, bs = "cr") + te(prev_freq, prev_len, bs = "cr")
#baseline_sprt_regression = psychometric ~ te(freq, len, bs = "cr") + te(prev_freq, prev_len, bs = "cr") + te(prev2_freq, prev2_len, bs = "cr")
#full_rt_regression = (psychometric ~ s(surprisal, bs = "cr", k = 20) + s(prev_surp, bs = "cr", k = 20)
                     #+ te(freq, len, bs = "cr") + te(prev_freq, prev_len, bs = "cr"))
#full_sprt_regression = (psychometric ~ s(surprisal, bs = "cr", k = 20) + s(prev_surp, bs = "cr", k = 20) + s(prev2_surp, bs = "cr", k = 20)
                        #+ te(freq, len, bs = "cr") + te(prev_freq, prev_len, bs = "cr") + te(prev2_freq, prev2_len, bs = "cr"))

baseline_rt_regression <- psychometric ~ freq + prev_freq + prev2_freq + prev3_freq + len + prev_len + prev2_len + prev3_len
baseline_sprt_regression <- psychometric ~ freq + prev_freq + len + prev_len

full_rt_regression <- psychometric ~ surprisal + prev_surp + prev2_surp + prev3_surp + freq + prev_freq + prev2_freq + prev3_freq + len + prev_len + prev2_len + prev3_len
full_sprt_regression <- psychometric ~ surprisal + prev_surp + freq + prev_freq + len + prev_len
  
#####
# Prepare frames/environments for storing results/objects.
baseline_results <- data.frame()
full_model_results <- data.frame()
baseline_residuals <- data.frame()
full_residuals <- data.frame()
log_lik_deltas <- data.frame()

# Randomly shuffle the rows, then cut them into K (near-)equal contiguous folds.
# The RNG is seeded so fold assignment, and therefore every cross-validated
# metric below, is reproducible across runs.
set.seed(42)
all_data <- all_data[sample(nrow(all_data)), ]
K <- 5
folds <- cut(seq_len(nrow(all_data)), breaks = K, labels = FALSE)
# K-fold cross-validation is performed below.

# Fit models for some fold of the data.
# Fit and evaluate the baseline (frequency/length-only) regression for one
# model--training--seed--corpus group via get_lm_data(). Dundee uses
# baseline_rt_regression and every other corpus uses baseline_sprt_regression.
# Only corpus[1] is inspected, so a length > 1 `corpus` (e.g. from
# unique(.$corpus)) cannot make `if` error (R >= 4.2).
baseline_corpus <- function(corpus, df, test_data, fold, env) {
  regression <- if (corpus[1] == "dundee") baseline_rt_regression else baseline_sprt_regression
  get_lm_data(df, test_data, regression, fold, env)
}
# Fit and evaluate the full (surprisal + frequency/length) regression for one
# model--training--seed--corpus group via get_lm_data(). Dundee uses
# full_rt_regression and every other corpus uses full_sprt_regression; only
# corpus[1] is inspected.
full_model_corpus <- function(corpus, df, test_data, fold, env) {
  formula_for_corpus <- if (corpus[1] == "dundee") {
    full_rt_regression
  } else {
    full_sprt_regression
  }
  get_lm_data(df, test_data, formula_for_corpus, fold, env)
}

# Environments in which get_lm_data() stores each fitted LM, under the name
# "<model> <training> <seed> <corpus> <fold>"; queried below for residuals and
# per-token delta-log-likelihoods.
baseline_env <- new.env()
full_env <- new.env()

for (i in seq_len(K)) {
  # Progress message instead of dumping the whole training frame each fold.
  loginfo(paste("Cross-validation fold", i, "of", K))

  # Segment the data by fold: fold i is held out, the rest is used for training.
  testIndexes <- which(folds == i, arr.ind = TRUE)
  test_data <- all_data[testIndexes, ]
  train_data <- all_data[-testIndexes, ]

  # Compute a baseline linear model for each model--training--seed--RT-corpus combination.
  baselines <- train_data %>%
    group_by(model, training, seed, corpus) %>%
      do(baseline_corpus(unique(.$corpus), ., test_data, i, baseline_env)) %>%
    ungroup() %>%
    mutate(seed = as.factor(seed),
           fold = i)

  baseline_results <- rbind(baseline_results, baselines)

  # Compute a full linear model for each model--training--seed--RT-corpus combination.
  full_models <- train_data %>%
    group_by(model, training, seed, corpus) %>%
      do(full_model_corpus(unique(.$corpus), ., test_data, i, full_env)) %>%
    ungroup() %>%
    mutate(seed = as.factor(seed),
           fold = i)

  full_model_results <- rbind(full_model_results, full_models)

  # Per-token delta-log-likelihoods (full minus baseline) on the held-out fold.
  fold_log_lik_deltas <- test_data %>%
    group_by(model, training, seed, corpus) %>%
      do(get_lm_delta_log_lik(., i, baseline_env, full_env)) %>%
    ungroup()

  log_lik_deltas <- rbind(log_lik_deltas, fold_log_lik_deltas)

  # Held-out per-row likelihoods and residuals for both models.
  fold_baseline_residuals <- test_data %>%
    group_by(model, training, seed, corpus) %>%
      do(get_lm_residuals(., i, baseline_env)) %>%
    ungroup()

  baseline_residuals <- rbind(baseline_residuals, fold_baseline_residuals)

  fold_full_residuals <- test_data %>%
    group_by(model, training, seed, corpus) %>%
      do(get_lm_residuals(., i, full_env)) %>%
    ungroup()

  full_residuals <- rbind(full_residuals, fold_full_residuals)
}
```

```{r}
#write.csv(full_residuals, "../data/analysis_checkpoints/full_residuals.csv")
#write.csv(baseline_residuals, "../data/analysis_checkpoints/baseline_residuals.csv")
```

```{r}
# Collapse the rows of log_lik_deltas into a mean delta log-likelihood and its
# standard error of the mean for every model--training--seed--corpus group.
model_deltas <- log_lik_deltas %>%
  group_by(model, training, seed, corpus) %>%
  summarise(
    mean_delta_log_lik = mean(delta_log_lik),
    sem_delta_log_lik = sd(delta_log_lik) / sqrt(n())
  )
```

```{r}
# Checkpoint the per-fold linear-model results.
write.csv(full_model_results, "../data/analysis_checkpoints/full_model_result.csv")
write.csv(baseline_results, "../data/analysis_checkpoints/baseline_results.csv")
# To resume from the checkpoint, read back exactly the files written above.
# write.csv() stores row names as an unnamed first column; row.names = 1
# restores them instead of adding a spurious `X` column.
#full_model_results = read.csv("../data/analysis_checkpoints/full_model_result.csv", row.names = 1)
#baseline_results = read.csv("../data/analysis_checkpoints/baseline_results.csv", row.names = 1)
```

```{r}
# Axis label for the delta metric used in the plots below. To plot MSE
# instead, swap the label and the source columns copied into delta_test_*.
metric <- "ΔLogLik"
# metric <- "-ΔMSE"

# Expose the selected metric under generic delta_test_* names, then drop the
# metric-specific source columns.
model_deltas <- model_deltas %>%
  mutate(
    delta_test_mean = mean_delta_log_lik,
    delta_test_sem = sem_delta_log_lik
    # MSE alternative:
    # delta_test_mean = mean_delta_mse,
    # delta_test_sem = sem_delta_mse
  ) %>%
  # MSE alternative: also drop mean_delta_mse and sem_delta_mse.
  select(-c(mean_delta_log_lik, sem_delta_log_lik))
model_deltas
```

```{r}
# Sanity check: training on train+test data should yield improved performance over training on just training data. (When evaluating on test data.)
# full_baselines = all_data %>%
#   group_by(model, training, seed, corpus) %>%
#   summarise(baseline_train_all_test_lik = logLik_test(lm(psychometric ~ len + freq + sent_pos, data=.), semi_join(test_data, ., by=c("training", "model", "seed", "corpus")), semi_join(test_data, ., by=c("training", "model", "seed", "corpus"))$psychometric)) %>%
#   ungroup()
# full_baselines
# 
# full_baselines %>%
#   right_join(baselines, by=c("seed", "training", "model", "corpus")) %>%
#   mutate(delta=baseline_train_all_test_lik-baseline_test_lik) %>%
#   select(-baseline_lik) # %>%
#   #select(-baseline_test_lik, -baseline_train_all_test_lik, -baseline_lik, -baseline_test_mse)
```

# Load language model data (SyntaxGym, PPL)

```{r}
# Per-model metadata, deduplicated to one row per model--training--seed.
language_model_data <- read.csv("../data/model_metadata.csv") %>%
  # Harmonise the GPT-2 model name with the spelling used elsewhere ("gpt2").
  mutate(
    model = as.character(model),
    model = as.factor(if_else(model == "gpt-2", "gpt2", model))
  ) %>%
  mutate(
    # Training set size in millions of tokens, keyed on the BLLIP subset prefix.
    train_size = case_when(
      str_starts(training, "bllip-lg") ~ 42,
      str_starts(training, "bllip-md") ~ 15,
      str_starts(training, "bllip-sm") ~ 5,
      str_starts(training, "bllip-xs") ~ 1
    ),
    # The vocabulary normally follows the training corpus, except that all
    # BPE models share one vocabulary across training corpora.
    training_vocab = as.factor(ifelse(str_detect(training, "gptbpe"),
                                      "gptbpe",
                                      as.character(training))),
    training_source = as.factor(str_replace(as.character(training), "-gptbpe", "")),
    seed = as.factor(seed)
  ) %>%
  select(-pid, -test_loss) %>%
  distinct(model, training, seed, .keep_all = TRUE)

# Seed counts on both sides, to eyeball that seeds line up before the join.
table(language_model_data$seed)
table(model_deltas$seed)
```

First, join the delta-metric data with the language-model metadata (SG score, test perplexity, training size).

```{r}
# Attach language-model metadata to every delta-metric row. The full outer
# join is followed by drop_na(), so only rows that matched on both sides and
# have no missing value in any column survive.
model_deltas <- drop_na(merge(model_deltas, language_model_data,
                              by = c("seed", "training", "model"),
                              all = TRUE))

model_deltas
```

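TODO (not yet implemented in this notebook): also join on the original per-fold linear-model results (`full_model_results`, `baseline_results`) rather than only the collapsed delta metrics.
This would support later regressions that do not collapse across folds.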


# Final data preprocessing

```{r Filter models and/or corpora}
# Drop the ordered-neurons model from every downstream analysis.
model_deltas <- filter(model_deltas, model != "ordered-neurons")
```


# Visualizations

## The basics

```{r, fig.cap="Corpus sizes"}
# Number of rows per corpus, shown as a bar chart and as a printed table.
ggplot(all_data, aes(x = corpus)) +
  geom_bar()
all_data %>%
  group_by(corpus) %>%
  summarise(n = n()) %>%
  print()
```


```{r, fig.cap="Word frequency distribution by corpus"}
ggplot(all_data, aes(x = freq, color = corpus)) +
  geom_density()
```

```{r, fig.cap="Word length distribution by corpus"}
ggplot(all_data, aes(x = len, color = corpus)) +
  geom_density()
```

```{r, fig.cap="Surprisal distribution by corpus"}
ggplot(all_data, aes(x = surprisal, color = corpus)) +
  geom_density()
```

## Predictive power and SG


```{r By model}
# Delta metric against SyntaxGym score, one panel per corpus. geom_smooth()
# inherits only x and y, so each panel gets a single linear fit pooled over
# all models.
model_deltas %>%
  ggplot(aes(x = sg_score, y = delta_test_mean)) +
  geom_errorbar(aes(ymin = delta_test_mean - delta_test_sem,
                    ymax = delta_test_mean + delta_test_sem)) +
  geom_smooth(method = "lm", se = TRUE) +
  geom_point(aes(color = training_vocab, shape = model),
             stat = "identity", position = "dodge", alpha = 1, size = 3) +
  labs(x = "Syntax Generalization Score",
       y = metric,
       title = "Syntactic Generalization vs. Predictive Power") +
  scale_color_manual(values = c("bllip-lg" = "#440154FF",
                                "bllip-md" = "#39568CFF",
                                "bllip-sm" = "#1F968BFF",
                                "bllip-xs" = "#73D055FF",
                                "gptbpe" = "#888888")) +
  facet_grid(~corpus, scales = "free") +
  theme(axis.text = element_text(size = 14),
        strip.text.x = element_text(size = 14),
        legend.text = element_text(size = 14),
        axis.title = element_text(size = 18),
        legend.position = "bottom")
# ggsave("./cogsci_images/sg_loglik.png", height = 5, width = 6)
```

### Regression analyses

We control for effects of perplexity by relating the residuals of a `performance ~ PPL` regression to the residuals of an `SG ~ PPL` regression.

```{r Residualized regression}
# Residualized analysis: remove the part of both the delta metric and the SG
# score that is linearly predictable from test perplexity (test_ppl), then
# plot the two sets of residuals against each other.
d_resid = model_deltas %>%
  # Drop incomplete rows first so every lm() below uses all rows of its group
  # and resid() returns exactly one value per row for mutate().
  drop_na() %>%
  
  # Residualize the delta metric w.r.t. PPL with one lm per corpus, fitted
  # across all model--training--seed rows of that corpus: a shared intercept
  # plus a separate PPL slope for each training level (training:test_ppl).
  # NOTE(review): do_stepwise_regression() uses training_vocab:test_ppl rather
  # than training:test_ppl -- confirm which specification is intended.
  group_by(corpus) %>%
    mutate(resid.delta = resid(lm(delta_test_mean ~ training:test_ppl))) %>%
  ungroup() %>%
  
  # Residualize SG score w.r.t. PPL with one lm per training vocabulary.
  group_by(training_vocab) %>%
    # NB no need for training:ppl interaction, since we're within-group.
    mutate(resid.sg = resid(lm(sg_score ~ test_ppl))) %>%
  ungroup() %>%
  # Mean and SEM of both residuals per model--training_vocab--corpus--seed.
  # training_vocab is "gptbpe" for every BPE training corpus but equals
  # training otherwise, so non-BPE groups contain at most one row and their
  # sd()-based SEMs are NA (those error bars are presumably dropped by ggplot).
  group_by(model, training_vocab, corpus, seed) %>%
    summarise(resid.delta.mean = mean(resid.delta),
              resid.delta.sem = sd(resid.delta) / sqrt(length(resid.delta)),
              resid.sg.mean = mean(resid.sg),
              resid.sg.sem = sd(resid.sg) / sqrt(length(resid.sg)))
# Plot residual delta metric against residual SG score, one panel per corpus.
d_resid %>%
  #filter(corpus != "bnc-brown") %>%
  ggplot(aes(x=resid.sg.mean, y=resid.delta.mean)) +
    geom_errorbar(aes(xmin=resid.sg.mean - resid.sg.sem,
                      xmax=resid.sg.mean + resid.sg.sem,
                      ymin=resid.delta.mean - resid.delta.sem,
                      ymax=resid.delta.mean + resid.delta.sem), alpha=0.3) +
    geom_smooth(method="lm", se=T) +
    geom_point(stat="identity", position="dodge", alpha=1, size=4, aes(shape=model, color=training_vocab)) +
    ylab(paste("Residual", metric)) +
    xlab("Residual Syntax Generalization Score") +
    ggtitle("Syntactic Generalization vs. Predictive Power") +
    scale_color_manual(values = c("bllip-lg"="#440154FF",
                                  "bllip-md"="#39568CFF",
                                  "bllip-sm"="#1F968BFF",
                                  "bllip-xs"="#73D055FF",
                                  "gptbpe"="#888888")) +
    facet_grid(.~corpus, scales="free") +
    theme(axis.text=element_text(size=14),
          strip.text.x = element_text(size=14),
          legend.text=element_text(size=14),
          axis.title=element_text(size=18),
          legend.position = "right")
ggsave("../images/cuny2020/ppl_sg.png",height=4.5,width=11)
```


```{r Stepwise regression}
# Nested-model comparison for one corpus: does SG score explain variance in
# the delta metric beyond per-vocabulary perplexity slopes?
#
# cur_corpus: corpus to analyse, matched against `data$corpus`.
# data:       table with columns corpus, delta_test_mean, training_vocab,
#             test_ppl and sg_score. Defaults to the global `model_deltas`,
#             so existing single-argument calls behave as before.
#
# Prints a separator, the corpus name and the ANOVA comparing the two fits,
# then returns summary() of the larger model (auto-printed at top level).
# Stops with an explicit message when `data` has no rows for `cur_corpus`.
do_stepwise_regression <- function(cur_corpus, data = model_deltas) {
  # Plain logical indexing avoids data masking (a `cur_corpus` column in
  # `data` can never shadow the argument). Rows with an NA corpus are
  # dropped, matching dplyr::filter().
  keep <- !is.na(data$corpus) & data$corpus == cur_corpus
  regression_data <- data[keep, , drop = FALSE]
  if (nrow(regression_data) == 0) {
    stop("No rows found for corpus '", cur_corpus, "'.", call. = FALSE)
  }

  print("----------------------")
  print(cur_corpus)

  lm1 <- lm(delta_test_mean ~ training_vocab:test_ppl, data = regression_data)
  lm2 <- lm(delta_test_mean ~ training_vocab:test_ppl + sg_score,
            data = regression_data)
  print(anova(lm1, lm2))
  summary(lm2)
}
# The bnc-brown comparison is disabled:
# do_stepwise_regression("bnc-brown")
invisible(lapply(c("dundee", "natural-stories"), function(corpus_name) {
  print(do_stepwise_regression(corpus_name))
}))
```

## Predictive power and perplexity

```{r}
# Delta metric against test perplexity, coloured by training vocabulary, one
# panel per corpus.
model_deltas %>%
  ggplot(aes(x = test_ppl, y = delta_test_mean, color = training_vocab)) +
  geom_errorbar(aes(ymin = delta_test_mean - delta_test_sem,
                    ymax = delta_test_mean + delta_test_sem),
                alpha = 0.4) +
  # geom_smooth(method = "lm", se = FALSE) +
  geom_point(aes(shape = model),
             stat = "identity", position = "dodge", alpha = 1, size = 4) +
  labs(x = "Test Perplexity",
       y = metric,
       title = "Test Perplexity vs. Predictive Power") +
  scale_color_manual(values = c("bllip-lg" = "#440154FF",
                                "bllip-md" = "#39568CFF",
                                "bllip-sm" = "#1F968BFF",
                                "bllip-xs" = "#73D055FF",
                                "gptbpe" = "#888888")) +
  facet_grid(~corpus, scales = "free") +
  # coord_cartesian(ylim = c(1, 16)) +
  # coord_cartesian(ylim = c(0, 150)) +
  theme(axis.text = element_text(size = 12),
        strip.text.x = element_text(size = 12),
        legend.text = element_text(size = 12),
        axis.title = element_text(size = 12),
        legend.position = "right")
ggsave("../images/cuny2020/ppl_loglik.png", height = 4.5, width = 11)

```


## Effect of training data size

```{r On predictive power}
# Delta metric against log training-set size, one panel per corpus x model.
model_deltas %>%
  mutate(train_size = log(train_size)) %>%
  ggplot(aes(x = train_size, y = delta_test_mean, color = model)) +
  geom_errorbar(aes(ymin = delta_test_mean - delta_test_sem,
                    ymax = delta_test_mean + delta_test_sem),
                width = 0.1) +
  geom_smooth(method = "lm", se = TRUE, alpha = 0.5) +
  geom_point(stat = "identity", position = "dodge", alpha = 1, size = 3) +
  labs(x = "Log Million Training Tokens",
       y = metric,
       title = "Training Size vs. Predictive Power") +
  facet_grid(corpus ~ model, scales = "free") +
  # scale_color_manual(values = c("#A42EF1", "#3894C8")) +
  theme(axis.text = element_text(size = 14),
        strip.text.x = element_text(size = 14),
        legend.text = element_text(size = 14),
        axis.title = element_text(size = 18),
        legend.position = "bottom")
# ggsave("./cogsci_images/training_loglik.png", height = 5, width = 6)
```


```{r On SG score}
# SG score against log training-set size, one panel per model architecture.
# sg_score and train_size are per-model metadata that the merge repeats on
# every corpus row of model_deltas. Keep a single row per
# model--training--seed; otherwise each model is plotted once per corpus and
# the lm fit treats the copies as independent points, understating its
# confidence band.
model_deltas %>%
  distinct(model, training, seed, .keep_all = TRUE) %>%
  mutate(train_size = log(train_size)) %>%
  ggplot(aes(x = train_size, y = sg_score, color = model)) +
  geom_smooth(method = "lm", se = TRUE, alpha = 0.5) +
  geom_point(stat = "identity", position = "dodge", alpha = 1, size = 3) +
  ylab("SG Score") +
  xlab("Log Million Training Tokens") +
  ggtitle("Training Size vs. Syntactic Generalization") +
  # scale_color_manual(values = c("#A42EF1", "#3894C8")) +
  facet_grid(~model, scales = "free") +
  theme(axis.text = element_text(size = 14),
        strip.text.x = element_text(size = 14),
        legend.text = element_text(size = 14),
        axis.title = element_text(size = 18),
        legend.position = "bottom")
# ggsave("./cogsci_images/training_sg.png", height = 5, width = 6)
```

## Smith & Levy reproduction

```{r}
# Smoothed reading time against surprisal per training corpus; models whose
# training name contains "bpe" are drawn with a separate line type.
all_data %>%
  # filter(surprisal < 15, surprisal > 0) %>%
  mutate(bpe = str_detect(training, "bpe"),
         training_source = str_replace(training, "-gptbpe", "")) %>%
  ggplot(aes(x = surprisal, y = psychometric,
             color = training_source, linetype = bpe)) +
  stat_smooth(se = TRUE, alpha = 0.5) +
  # geom_errorbar(color = "black", width = .2, position = position_dodge(width = .9), alpha = 0.3) +
  # geom_point(stat = "identity", position = "dodge", alpha = 1, size = 3) +
  labs(x = "Surprisal (bits)",
       y = "Processing Time (ms)",
       title = "Surprisal vs. Reading Time / Gaze Duration") +
  facet_grid(corpus ~ model, scales = "free") +
  # scale_color_manual(values = c("bllip-lg" = "#440154FF",
  #                               "bllip-md" = "#39568CFF",
  #                               "bllip-sm" = "#1F968BFF",
  #                               "bllip-xs" = "#73D055FF",
  #                               "bllip-lg-gptbpe" = "#888888",
  #                               "bllip-md-gptbpe" = "#888888",
  #                               "bllip-sm-gptbpe" = "#888888",
  #                               "bllip-xs-gptbpe" = "#888888")) +
  theme(axis.text = element_text(size = 14),
        axis.text.y = element_text(size = 10),
        strip.text.x = element_text(size = 14),
        legend.text = element_text(size = 14),
        axis.title = element_text(size = 18),
        legend.position = "right")
ggsave("../images/cuny2020/surp_corr.png", height = 4.5, width = 12)
```

### Investigate vanilla

```{r}
# Raw reading time against surprisal for the vanilla model, one panel per
# corpus x training corpus.
all_data %>%
  # filter(surprisal < 15, surprisal > 0) %>%
  filter(model == "vanilla") %>%
  ggplot(aes(x = surprisal, y = psychometric)) +
  # stat_smooth(se = TRUE, alpha = 0.5) +
  # geom_errorbar(color = "black", width = .2, position = position_dodge(width = .9), alpha = 0.3) +
  geom_point(alpha = 0.1) +
  labs(x = "Surprisal (bits)",
       y = "Processing Time (ms)",
       title = "Surprisal vs. Reading Time / Gaze Duration: Vanilla") +
  facet_grid(corpus ~ training, scales = "free")
  # scale_color_manual(values = c("bllip-lg" = "#440154FF",
  #                               "bllip-md" = "#39568CFF",
  #                               "bllip-sm" = "#1F968BFF",
  #                               "bllip-xs" = "#73D055FF",
  #                               "bllip-lg-gptbpe" = "#888888",
  #                               "bllip-md-gptbpe" = "#888888",
  #                               "bllip-sm-gptbpe" = "#888888",
  #                               "bllip-xs-gptbpe" = "#888888"))
```

```{r Tony Blair is the source of that right-side dip}
# Dundee rows where vanilla (bllip-lg) surprisal exceeds 20 but the
# psychometric measure is below 300.
filter(all_data,
       corpus == "dundee", model == "vanilla", training == "bllip-lg",
       surprisal > 20, psychometric < 300)
```

```{r}
# Full-model residuals for vanilla / bllip-lg on Dundee, largest first, then
# the surprisal distribution of rows whose residual exceeds 150.
full_residuals %>%
  filter(corpus == "dundee", model == "vanilla", training == "bllip-lg") %>%
  arrange(desc(resid)) %>%
  print()
full_residuals %>%
  filter(corpus == "dundee", model == "vanilla", training == "bllip-lg",
         resid > 150) %>%
  arrange(desc(resid)) %>%
  ggplot(aes(x = surprisal)) +
  geom_density()
```


### Investigate RNNG

```{r}
# Raw reading time against surprisal for the RNNG model, one panel per
# corpus x training corpus.
all_data %>%
  # filter(surprisal < 15, surprisal > 0) %>%
  filter(model == "rnng") %>%
  ggplot(aes(x = surprisal, y = psychometric)) +
  # stat_smooth(se = TRUE, alpha = 0.5) +
  # geom_errorbar(color = "black", width = .2, position = position_dodge(width = .9), alpha = 0.3) +
  geom_point(alpha = 0.1) +
  labs(x = "Surprisal (bits)",
       y = "Processing Time (ms)",
       title = "Surprisal vs. Reading Time / Gaze Duration: RNNG") +
  facet_grid(corpus ~ training, scales = "free")
```

```{r Tony Blair is the source of that right-side dip for RNNG too}
# Dundee rows where RNNG (bllip-lg) surprisal exceeds 20 but the
# psychometric measure is below 300.
filter(all_data,
       corpus == "dundee", model == "rnng", training == "bllip-lg",
       surprisal > 20, psychometric < 300)
```

```{r}
# Full-model residuals for RNNG / bllip-lg on Dundee, largest first, then the
# surprisal distribution of rows whose residual exceeds 150.
full_residuals %>%
  filter(corpus == "dundee", model == "rnng", training == "bllip-lg") %>%
  arrange(desc(resid)) %>%
  print()
full_residuals %>%
  filter(corpus == "dundee", model == "rnng", training == "bllip-lg",
         resid > 150) %>%
  arrange(desc(resid)) %>%
  ggplot(aes(x = surprisal)) +
  geom_density()
```

### Investigate ngram vs vanilla

```{r}
# Means of freq, psychometric and full-model residual per (corpus, code) item
# for the 5-gram and vanilla models trained on bllip-sm, joined so their
# residuals can be compared item by item.
ngram_resids <- full_residuals %>%
  filter(model == "5gram", training == "bllip-sm") %>%
  group_by(corpus, code) %>%
  summarise(across(c(freq, psychometric, resid), mean))
vanilla_resids <- full_residuals %>%
  filter(model == "vanilla", training == "bllip-sm") %>%
  group_by(corpus, code) %>%
  summarise(across(c(freq, psychometric, resid), mean))
resids_joined <- left_join(ngram_resids, vanilla_resids,
                           by = c("corpus", "code"),
                           suffix = c(".ngram", ".vanilla"))

# 5-gram vs vanilla residuals; the red line marks equality.
ggplot(resids_joined, aes(x = resid.ngram, y = resid.vanilla)) +
  geom_point() +
  geom_abline(slope = 1, color = "red") +
  facet_grid(~corpus)

# Distribution of the absolute residual difference.
resids_joined %>%
  mutate(resid_abs_diff = abs(resid.ngram - resid.vanilla)) %>%
  ggplot(aes(x = resid_abs_diff)) +
  geom_density() +
  facet_grid(~corpus)

# Absolute residual difference as a function of word frequency.
resids_joined %>%
  mutate(resid_abs_diff = abs(resid.ngram - resid.vanilla)) %>%
  ggplot(aes(x = freq.ngram, y = resid_abs_diff)) +
  geom_point(alpha = 0.1) +
  geom_smooth()
```

### Investigate gptbpe vs vanilla

```{r}
# Means of freq, psychometric and full-model residual per (corpus, code) item
# for GPT-2 (bllip-sm-gptbpe) and vanilla (bllip-sm), joined so their
# residuals can be compared item by item.
gpt_resids <- full_residuals %>%
  filter(model == "gpt2", training == "bllip-sm-gptbpe") %>%
  group_by(corpus, code) %>%
  summarise(across(c(freq, psychometric, resid), mean))
vanilla_resids <- full_residuals %>%
  filter(model == "vanilla", training == "bllip-sm") %>%
  group_by(corpus, code) %>%
  summarise(across(c(freq, psychometric, resid), mean))
resids_joined <- left_join(gpt_resids, vanilla_resids,
                           by = c("corpus", "code"),
                           suffix = c(".gpt", ".vanilla"))

# GPT-2 vs vanilla residuals; the red line marks equality.
ggplot(resids_joined, aes(x = resid.gpt, y = resid.vanilla)) +
  geom_point() +
  geom_abline(slope = 1, color = "red") +
  facet_grid(~corpus)

# Distribution of the absolute residual difference.
resids_joined %>%
  mutate(resid_abs_diff = abs(resid.gpt - resid.vanilla)) %>%
  ggplot(aes(x = resid_abs_diff)) +
  geom_density() +
  facet_grid(~corpus)

# Absolute residual difference as a function of word frequency.
resids_joined %>%
  mutate(resid_abs_diff = abs(resid.gpt - resid.vanilla)) %>%
  ggplot(aes(x = freq.gpt, y = resid_abs_diff)) +
  geom_point(alpha = 0.1) +
  geom_smooth()
```

### Investigate residuals overall

```{r}
# Difference between full-model and baseline residuals (full - baseline) for
# each matched row, plotted against surprisal in one panel per model x corpus.
# The unused training_source / bpe columns the original computed are dropped.
full_residuals %>%
  right_join(baseline_residuals,
             by = c("corpus", "code", "model", "training", "seed"),
             suffix = c(".full", ".baseline")) %>%
  mutate(resid_delta = resid.full - resid.baseline) %>%
  ggplot(aes(x = surprisal.full, y = resid_delta, color = training)) +
  facet_grid(model ~ corpus) +
  geom_point(alpha = 0.1, size = 0.5)
```

```{r}
filter(language_model_data, model == "gpt2")
```

